{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Object Detection Models in SageMaker with Augmented Manifests\n",
    "\n",
    "This notebook demonstrates the use of an \"augmented manifest\" to train an object detection machine learning model with Amazon SageMaker."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
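    "Here we import the required libraries, obtain the execution role, instantiate a SageMaker session and an S3 resource, and look up the training image containing the object detection algorithm. The S3 file paths for input and output data are defined in the cells that follow."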
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import re\n",
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "import time\n",
    "from time import gmtime, strftime\n",
    "import json\n",
    "\n",
    "role = get_execution_role()\n",
    "sess = sagemaker.Session()\n",
    "s3 = boto3.resource('s3')\n",
    "\n",
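    "# Look up the built-in object detection training image for the current region.\n",
    "# NB. This call is from version 1 of the SageMaker Python SDK; version 2 replaces it with sagemaker.image_uris.retrieve.\n",
    "training_image = sagemaker.amazon.amazon_estimator.get_image_uri(boto3.Session().region_name, 'object-detection', repo_version='latest')"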
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Required Inputs\n",
    "\n",
    "*Be sure to edit the file names and paths below for your own use!*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "augmented_manifest_filename_train = 'augmented-manifest-train.manifest' # Replace with the filename for your training data.\n",
    "augmented_manifest_filename_validation = 'augmented-manifest-validation.manifest' # Replace with the filename for your validation data.\n",
    "bucket_name = \"ground-truth-augmented-manifest-demo\" # Replace with your bucket name.\n",
    "s3_prefix = 'object-detection' # Replace with the S3 prefix where your data files reside.\n",
    "s3_output_path = 's3://{}/output'.format(bucket_name) # Replace with your desired output directory."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The setup section concludes with a few more definitions and constants."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defines paths for use in the training job request.\n",
    "s3_train_data_path = 's3://{}/{}/{}'.format(bucket_name, s3_prefix, augmented_manifest_filename_train)\n",
    "s3_validation_data_path = 's3://{}/{}/{}'.format(bucket_name, s3_prefix, augmented_manifest_filename_validation)\n",
    "\n",
    "print(\"Augmented manifest for training data: {}\".format(s3_train_data_path))\n",
    "print(\"Augmented manifest for validation data: {}\".format(s3_validation_data_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Understanding the Augmented Manifest format"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Augmented manifests provide two key benefits. First, the format is consistent with that of a labeling job output manifest. This means that you can take the output manifests from a Ground Truth labeling job and use them as inputs to SageMaker training jobs without any additional translation or reformatting, whether the dataset objects were entirely human-labeled, entirely machine-labeled, or anything in between. Second, the dataset objects and their corresponding ground truth labels/annotations are captured *inline*. This halves the required number of channels, since you no longer need one channel for the dataset objects and another for the associated ground truth labels/annotations.\n",
    "\n",
    "The augmented manifest format is essentially the [JSON Lines format](http://jsonlines.org/), also called newline-delimited JSON. This format consists of an arbitrary number of well-formed, fully-defined JSON objects, each on a separate line. Augmented manifests must contain a field that defines a dataset object, and a field that defines the corresponding annotation.\n",
    "\n",
    "The Ground Truth output format is discussed more fully for various types of labeling jobs in the [official documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/sms-data-output.html). Let's look at an example for an object detection problem."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "{<span style=\"color:blue\">\"source-ref\"</span>: \"s3://bucket_name/path_to_a_dataset_object.jpeg\", <span style=\"color:blue\">\"labeling-job-name\"</span>: {\"annotations\":[{\"class_id\":\"0\",`<bounding box dimensions>`}],\"image_size\":[{`<image size dimensions>`}]}}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first field will always be either `source` or `source-ref`. This defines an individual dataset object. The name of the second field depends on whether the labeling job was created from the SageMaker console or through the Ground Truth API. If the job was created through the console, then the name of the field will be the labeling job name. Alternatively, if the job was created through the API, then this field maps to the `LabelAttributeName` parameter in the API.\n",
    "\n",
    "The training job request requires a parameter called `AttributeNames`. This should be a two-element list of strings, where the first string is \"source-ref\", and the second string is the label attribute name from the augmented manifest. This corresponds to the <span style=\"color:blue\">blue text</span> in the example above. In this case, we would define `attribute_names = [\"source-ref\", \"labeling-job-name\"]`.\n",
    "\n",
    "*Be sure to carefully inspect your augmented manifest so that you can define the `attribute_names` variable below.*"
   ]
  },
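  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the rule above, the sketch below reads one manifest line and picks out the two field names that would go into `attribute_names`. The sample line, its bounding box values, and the helper `guess_attribute_names` are hypothetical and exist only for this example. The heuristic assumes the label field is the only JSON object that contains an `annotations` key, so always confirm the result against your own manifest.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# Hypothetical manifest line, built here only for illustration.\n",
    "sample_line = json.dumps({\n",
    "    'source-ref': 's3://bucket_name/path_to_a_dataset_object.jpeg',\n",
    "    'labeling-job-name': {\n",
    "        'annotations': [{'class_id': 0, 'left': 10, 'top': 20, 'width': 30, 'height': 40}],\n",
    "        'image_size': [{'width': 500, 'height': 400, 'depth': 3}],\n",
    "    },\n",
    "    'labeling-job-name-metadata': {'type': 'groundtruth/object-detection'},\n",
    "})\n",
    "\n",
    "def guess_attribute_names(manifest_line):\n",
    "    # Hypothetical helper: return [source field, label field] for one manifest line.\n",
    "    record = json.loads(manifest_line)\n",
    "    source_keys = [k for k in record if k in ('source-ref', 'source')]\n",
    "    # Assumption: the label field is the only dict-valued field holding 'annotations'.\n",
    "    label_keys = [k for k, v in record.items() if isinstance(v, dict) and 'annotations' in v]\n",
    "    if len(source_keys) != 1 or len(label_keys) != 1:\n",
    "        raise ValueError('Could not identify the fields unambiguously; inspect the manifest manually.')\n",
    "    return [source_keys[0], label_keys[0]]\n",
    "\n",
    "print(guess_attribute_names(sample_line))\n",
    "```\n",
    "\n",
    "The metadata field is skipped because it does not contain an `annotations` key."
   ]
  },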
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preview Input Data\n",
    "\n",
    "Let's read the augmented manifest so we can inspect its contents to better understand the format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the S3 key from its parts rather than splitting the full S3 URI.\n",
    "augmented_manifest_s3_key = '{}/{}'.format(s3_prefix, augmented_manifest_filename_train)\n",
    "s3_obj = s3.Object(bucket_name, augmented_manifest_s3_key)\n",
    "augmented_manifest = s3_obj.get()['Body'].read().decode('utf-8')\n",
    "# Drop blank lines (e.g., from a trailing newline) so they are not counted as samples.\n",
    "augmented_manifest_lines = [line for line in augmented_manifest.split('\\n') if line.strip()]\n",
    "\n",
    "num_training_samples = len(augmented_manifest_lines) # Compute number of training samples for use in training job request.\n",
    "\n",
    "print('Preview of Augmented Manifest File Contents')\n",
    "print('-------------------------------------------')\n",
    "print('\\n')\n",
    "\n",
    "# Preview at most two lines, in case the manifest is shorter than that.\n",
    "for i in range(min(2, len(augmented_manifest_lines))):\n",
    "    print('Line {}'.format(i+1))\n",
    "    print(augmented_manifest_lines[i])\n",
    "    print('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The key feature of the augmented manifest is that it has both the data object itself (i.e., the image), and the annotation in-line in a single JSON object. Note that the `annotations` keyword contains dimensions and coordinates (e.g., width, top, height, left) for bounding boxes! The augmented manifest can contain an arbitrary number of lines, as long as each line adheres to this format.\n",
    "\n",
    "Let's discuss this format in more detail by describing each parameter of this JSON object format.\n",
    "\n",
    "* The `source-ref` field defines a single dataset object, which in this case is an image over which bounding boxes should be drawn. Note that the name of this field is arbitrary. \n",
    "* The `object-detection-job-name` field defines the ground truth bounding box annotations that pertain to the image identified in the `source-ref` field. As mentioned above, note that the name of this field is arbitrary. You must take care to define this field in the `AttributeNames` parameter of the training job request, as shown later on in this notebook.\n",
    "* Because this example augmented manifest was generated through a Ground Truth labeling job, this example also shows an additional field called `object-detection-job-name-metadata`. This field contains various pieces of metadata from the labeling job that produced the bounding box annotation(s) for the associated image, e.g., the creation date, confidence scores for the annotations, etc. This field is ignored during the training job. However, to make it as easy as possible to translate Ground Truth labeling jobs into trained SageMaker models, it is safe to include this field in the augmented manifest you supply to the training job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "attribute_names = [\"source-ref\",\"XXXX\"] # Replace as appropriate for your augmented manifest."
   ]
  },
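  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before submitting a training job, it can help to confirm that every record in the manifest actually contains the fields named in `attribute_names`. The following is a minimal sketch of such a check, not part of the SageMaker API: the helper `find_bad_lines` and the two-line demo manifest are hypothetical and used only for illustration.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "def find_bad_lines(manifest_text, attribute_names):\n",
    "    # Hypothetical helper: return 1-based line numbers of records missing any requested attribute.\n",
    "    bad = []\n",
    "    for number, line in enumerate(manifest_text.split('\\n'), start=1):\n",
    "        if not line.strip():\n",
    "            continue  # Skip blank lines such as the one left by a trailing newline.\n",
    "        record = json.loads(line)\n",
    "        if any(name not in record for name in attribute_names):\n",
    "            bad.append(number)\n",
    "    return bad\n",
    "\n",
    "# Hypothetical two-line manifest; the second record lacks the label field.\n",
    "demo_manifest = '\\n'.join([\n",
    "    json.dumps({'source-ref': 's3://bucket_name/a.jpeg', 'labeling-job-name': {'annotations': []}}),\n",
    "    json.dumps({'source-ref': 's3://bucket_name/b.jpeg'}),\n",
    "])\n",
    "\n",
    "print(find_bad_lines(demo_manifest, ['source-ref', 'labeling-job-name']))\n",
    "```\n",
    "\n",
    "In this notebook the same check could be applied to the `augmented_manifest` string read earlier, once `attribute_names` has been set."
   ]
  },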
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Training Job\n",
    "\n",
    "First, we'll construct the request for the training job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    if attribute_names == [\"source-ref\",\"XXXX\"]:\n",
    "        raise Exception(\"The 'attribute_names' variable is set to default values. Please check your augmented manifest file for the label attribute name and set the 'attribute_names' variable accordingly.\")\n",
    "except NameError:\n",
    "    raise Exception(\"The attribute_names variable is not defined. Please check your augmented manifest file for the label attribute name and set the 'attribute_names' variable accordingly.\")\n",
    "\n",
    "# Create unique job name \n",
    "job_name_prefix = 'groundtruth-augmented-manifest-demo'\n",
    "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
    "job_name = job_name_prefix + timestamp\n",
    "\n",
    "training_params = \\\n",
    "{\n",
    "    \"AlgorithmSpecification\": {\n",
    "        \"TrainingImage\": training_image, # NB. This is one of the named constants defined in the first cell.\n",
    "        \"TrainingInputMode\": \"Pipe\"\n",
    "    },\n",
    "    \"RoleArn\": role,\n",
    "    \"OutputDataConfig\": {\n",
    "        \"S3OutputPath\": s3_output_path\n",
    "    },\n",
    "    \"ResourceConfig\": {\n",
    "        \"InstanceCount\": 1,   \n",
    "        \"InstanceType\": \"ml.p3.2xlarge\",\n",
    "        \"VolumeSizeInGB\": 50\n",
    "    },\n",
    "    \"TrainingJobName\": job_name,\n",
    "    \"HyperParameters\": { # NB. These hyperparameters are at the user's discretion and are beyond the scope of this demo.\n",
    "         \"base_network\": \"resnet-50\",\n",
    "         \"use_pretrained_model\": \"1\",\n",
    "         \"num_classes\": \"1\",\n",
    "         \"mini_batch_size\": \"1\",\n",
    "         \"epochs\": \"5\",\n",
    "         \"learning_rate\": \"0.001\",\n",
    "         \"lr_scheduler_step\": \"3,6\",\n",
    "         \"lr_scheduler_factor\": \"0.1\",\n",
    "         \"optimizer\": \"rmsprop\",\n",
    "         \"momentum\": \"0.9\",\n",
    "         \"weight_decay\": \"0.0005\",\n",
    "         \"overlap_threshold\": \"0.5\",\n",
    "         \"nms_threshold\": \"0.45\",\n",
    "         \"image_shape\": \"300\",\n",
    "         \"label_width\": \"350\",\n",
    "         \"num_training_samples\": str(num_training_samples)\n",
    "    },\n",
    "    \"StoppingCondition\": {\n",
    "        \"MaxRuntimeInSeconds\": 86400\n",
    "    },\n",
    "    \"InputDataConfig\": [\n",
    "        {\n",
    "            \"ChannelName\": \"train\",\n",
    "            \"DataSource\": {\n",
    "                \"S3DataSource\": {\n",
    "                    \"S3DataType\": \"AugmentedManifestFile\", # NB. Augmented Manifest\n",
    "                    \"S3Uri\": s3_train_data_path,\n",
    "                    \"S3DataDistributionType\": \"FullyReplicated\",\n",
    "                    \"AttributeNames\": attribute_names # NB. This must correspond to the JSON field names in your augmented manifest.\n",
    "                }\n",
    "            },\n",
    "            \"ContentType\": \"application/x-recordio\",\n",
    "            \"RecordWrapperType\": \"RecordIO\",\n",
    "            \"CompressionType\": \"None\"\n",
    "        },\n",
    "        {\n",
    "            \"ChannelName\": \"validation\",\n",
    "            \"DataSource\": {\n",
    "                \"S3DataSource\": {\n",
    "                    \"S3DataType\": \"AugmentedManifestFile\", # NB. Augmented Manifest\n",
    "                    \"S3Uri\": s3_validation_data_path,\n",
    "                    \"S3DataDistributionType\": \"FullyReplicated\",\n",
    "                    \"AttributeNames\": attribute_names # NB. This must correspond to the JSON field names in your augmented manifest.\n",
    "                }\n",
    "            },\n",
    "            \"ContentType\": \"application/x-recordio\",\n",
    "            \"RecordWrapperType\": \"RecordIO\",\n",
    "            \"CompressionType\": \"None\"\n",
    "        }\n",
    "    ]\n",
    "}\n",
    " \n",
    "print('Training job name: {}'.format(job_name))\n",
    "print('\\nInput Data Location: {}'.format(training_params['InputDataConfig'][0]['DataSource']['S3DataSource']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we create the Amazon SageMaker training job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client = boto3.client(service_name='sagemaker')\n",
    "client.create_training_job(**training_params)\n",
    "\n",
    "# Confirm that the training job has started\n",
    "status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n",
    "print('Training job current status: {}'.format(status))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Poll once per minute until the job reaches a terminal state.\n",
    "# NB. 'Stopped' is included so the loop also exits if the job is stopped manually.\n",
    "while True:\n",
    "    job_description = client.describe_training_job(TrainingJobName=job_name)\n",
    "    TrainingJobStatus = job_description['TrainingJobStatus']\n",
    "    SecondaryStatus = job_description['SecondaryStatus']\n",
    "    print(TrainingJobStatus, SecondaryStatus)\n",
    "    if TrainingJobStatus in ('Completed', 'Failed', 'Stopped'):\n",
    "        break\n",
    "    time.sleep(60)"
   ]
  },
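  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The polling loop above places no upper bound on how long it waits. One possible variant, sketched below, caps the number of polls and takes the status lookup as a function argument so that it can be tried without an AWS account. The helper `wait_for_job`, its parameters, and the fake responses are hypothetical; in this notebook the `describe` argument would be something like `lambda: client.describe_training_job(TrainingJobName=job_name)`.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "TERMINAL_STATES = ('Completed', 'Failed', 'Stopped')\n",
    "\n",
    "def wait_for_job(describe, poll_seconds=60, max_polls=1000):\n",
    "    # Hypothetical helper: poll describe() until it reports a terminal status, then return it.\n",
    "    for _ in range(max_polls):\n",
    "        status = describe()['TrainingJobStatus']\n",
    "        if status in TERMINAL_STATES:\n",
    "            return status\n",
    "        time.sleep(poll_seconds)\n",
    "    raise TimeoutError('Job did not reach a terminal state within the allowed number of polls.')\n",
    "\n",
    "# Stand-in for client.describe_training_job, used only to exercise the helper.\n",
    "fake_responses = iter([{'TrainingJobStatus': 'InProgress'}, {'TrainingJobStatus': 'Completed'}])\n",
    "print(wait_for_job(lambda: next(fake_responses), poll_seconds=0))\n",
    "```"
   ]
  },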
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_info = client.describe_training_job(TrainingJobName=job_name)\n",
    "print(training_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "That's it! Let's review what we've learned.\n",
    "* Augmented manifests are a new format that provides a seamless interface between Ground Truth labeling jobs and SageMaker training jobs.\n",
    "* In augmented manifests, you specify the dataset objects and the associated annotations in-line.\n",
    "* Be sure to pay close attention to the `AttributeNames` parameter in the training job request. The strings you specify in this field must correspond to those that are present in your augmented manifest."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
